-
Notifications
You must be signed in to change notification settings - Fork 224
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable tuning in Batch norm CK solver #3326
Conversation
IsApplicable/GetSolution in general should throw instead of asserting for several reasons:
assert(i < vec.size());
auto x = vec[i]; With |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
@@ -240,7 +240,8 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex | |||
|
|||
fmha_runtime_args.p_drop = probability; | |||
fmha_runtime_args.drop_seed_offset = | |||
std::make_pair(dataFwd.dropoutSeedData, dataFwd.dropoutOffsetData); | |||
std::make_pair(reinterpret_cast<uint64_t>(dataFwd.dropoutSeedData), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@CAHEK7 had to this to silence the compiler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to double check - are you going to convert gpu pointer to integer? Is the code using fmha_runtime_args.drop_seed_offset
aware that it's actually gpu pointer?
For me that reinterpret_cast
doesn't look like a correct solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes! its a GPU pointer but I think for now we have disabled the dropout https://github.com/ROCm/MIOpen/blob/develop/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp#L234
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to make it explicitly 0 and put the same comment as
MIOpen/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp
Lines 237 to 239 in 0c24160
// TODO : Change API to take in probability value as host side value instead of device | |
// pointer to match CK API. Calling a blocking hipMemcpy will cause issues with stream, | |
// and isn't async. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks some improvements can be done here. But it will work even without any changes.
The only thing which really concerns me is reinterpret_cast
in src/solver/mha/mha_ck_fa_v2_solver_forward.cpp
- it doesn't look like a correct code.
@bghimireamd Windows builds are failing, the error messages look like something caused by this PR, could you take a look? |
This PR only enabling tuning in CK solver.